# encoding: utf-8
"""
@Time:2022/3/21 16:01
@Author: shujin sun
@Desc: 缺陷检测测试
"""

import cv2 as cv
import numpy as np
import matplotlib.pyplot as plt
import math
import sys

# 中文显示
plt.rcParams['font.sans-serif'] = ['SimHei']  # 显示中文
plt.rcParams['axes.unicode_minus'] = False  # 显示坐标轴负号


class DetectDefect:
    def __init__(self):
        self.little_img_file = ''
        self.big_img_file = ''

        self.big_img = np.array(0)
        self.little_img = np.array(0)
        self.roi_img = np.array(0)
        # 缺陷检测二值化结果
        self.bw_img = np.array(0)

        # big image 的大小
        self.w = 0
        self.h = 0
        # 重点关注区域，即中间区域
        self.roi_rect = [0, 0, 0, 0]
        self.roi_patchs = []
        # 缺陷检测阈值
        self.thresh_value = 3000

        # 所有的匹配方法
        self.match_methods = ['cv.TM_CCOEFF',  # 相关系数匹配法
                              'cv.TM_CCOEFF_NORMED',  # 归一化版本
                              'cv.TM_CCORR',  # 相关匹配法
                              'cv.TM_CCORR_NORMED',
                              'cv.TM_SQDIFF',  # 平方差匹配法
                              'cv.TM_SQDIFF_NORMED']

    def set_little_image(self, _file):
        self.little_img_file = _file
        self.little_img = cv.imread(_file)# cv.IMREAD_GRAYSCALE

    def set_big_image(self, _file):
        self.big_img_file = _file
        self.big_img = cv.imread(_file)
        print(self.big_img.shape)
        self.h, self.w = self.big_img.shape[:2]
        print('big image shape: ', self.w, self.h)

    def set_thresh_value(self, thresh_value):
        """设置缺陷检测阈值门限"""
        self.thresh_value = thresh_value

    def conduct_edge(self, src):
        '''边缘检测'''
        # 计算像素中位数
        gray_img = cv.cvtColor(src, cv.COLOR_BGR2GRAY)
        median_intensity = np.median(gray_img)

        lower_threshold = int(max(0, (1.0 - 0.33) * median_intensity))
        upper_threshold = int(min(255, (1.0 + 0.33) * median_intensity))

        canny_img = cv.Canny(gray_img, lower_threshold, upper_threshold)

        plt.imshow(canny_img, cmap='gray')
        plt.title('canny edge detect result')
        plt.xticks([]), plt.yticks([])
        plt.show()

    def conduct_binary_adaptive(self, src):
        """二值化操作,自适应方法"""
        src = cv.cvtColor(src, cv.COLOR_BGR2GRAY)
        bw_img = cv.adaptiveThreshold(src,
                                      255,
                                      cv.ADAPTIVE_THRESH_MEAN_C,#cv.ADAPTIVE_THRESH_GAUSSIAN_C,
                                      cv.THRESH_BINARY,
                                      5,
                                      7)
        return bw_img

    def conduct_binary(self, src):
        """二值化操作"""
        src = cv.cvtColor(src, cv.COLOR_BGR2GRAY)
        mean_intensity = np.mean(src)
        median_intensity = np.median(src)
        # print('图像灰度均值、中值：', mean_intensity, median_intensity)
        _, bw_img = cv.threshold(src, mean_intensity, 255, cv.THRESH_BINARY_INV+cv.THRESH_OTSU)
        return bw_img

    def match_template(self, big, little):
        img = big.copy()
        # eval函数，将字符串作为脚本代码来执行
        idx = 3  # 1 or 3方法效果最好
        print('选择的匹配算法：', self.match_methods[idx])
        method = eval(self.match_methods[idx])
        res = cv.matchTemplate(img, little, method)

        min_val, max_val, min_loc, max_loc = cv.minMaxLoc(res)
        top_left = max_loc
        h, w = little.shape[:2]
        bottom_right = (top_left[0] + w, top_left[1] + h)
        return top_left, bottom_right

    def split_patches(self):
        """异常检测
        1.两种choice匹配
        choice1：选择所有匹配算法计算，通过比较评估值，计算最优匹配结果
        choice2：仅挑选一种匹配算法计算，通过前期测试验证，事先固化最稳健的算法
        另外，可考虑图像偏转一定角度（极小范围，可多个值，该角度值若能用算法估算更佳），旋转后再进行匹配、检测

        2.上下扩充
        在roi区域中匹配结束之后，上下进行扩充选择，进行缺陷区域检测，防止漏网之鱼

        3.二次匹配、检测
        在结果中按照固定尺寸分割小块，随机选择1块，or二值化求取面积最大快作为模板块（或者在little_img中截取）

        4.通过模板块，在其他区域进行相关或者相减计算（考虑先进行二值化），根据一定阈值，给出是否缺陷判决

        在roi块中匹配结束之后，再转换至整体图像中位置（视需求而定，忽视此操作也可）
        """
        top_left, bottom_right = self.match_template(self.roi_img, self.little_img)
        print('roi中间结果：', top_left, bottom_right)

        # 通过上下平移获取所有可能的匹配结果
        h = self.little_img.shape[0]  # 小图高度
        n_upper = -int(top_left[1] / h)
        n_bottom = int((self.roi_img.shape[0] - bottom_right[1]) / h)
        patch_arrays = [] # 保存切片在ROI中的位置

        tmp_img = self.roi_img.copy()
        for i in range(n_upper, n_bottom + 1):
            # 上下拓展，包括自身
            t = top_left[1] + i * h
            b = bottom_right[1] + i * h
            if i != 0:
                t = t + int(i / abs(i))
                b = b + int(i / abs(i))
            print('i={0},范围：{1}-{2}'.format(i, t, b))

            # 截取图像切片
            patch_arrays.append([t, b, top_left[0], bottom_right[0]])
            src = tmp_img[t:b, top_left[0]:bottom_right[0]]

            # 保存为文件
            # save_name = '{0}_slice_{1}.bmp'.format(big_img_file[:-4], i)
            # print(src.shape, save_name)
            # cv.imwrite(save_name, src)

            # 显示中间匹配结果
            cv.rectangle(tmp_img, (top_left[0], t), (bottom_right[0], b), 255, 2)

        self.roi_patchs = np.array(patch_arrays)
        print('切片位置：', self.roi_patchs, self.roi_patchs.shape)
        plt.imshow(tmp_img, cmap='gray')
        plt.title('ROI匹配中间结果')
        plt.show()

    def set_roi_rect(self, ratio_left=0.27, ratio_right=0.82, ratio_upper=0.2, ratio_bottom=0.8):
        '''设置中间关注区域，左右上下位置的ratio比例'''
        self.roi_rect[0] = int(self.w * ratio_left)
        self.roi_rect[1] = int(self.w * ratio_right)
        self.roi_rect[2] = int(self.h * ratio_upper)
        self.roi_rect[3] = int(self.h * ratio_bottom)

    def show_roi_img(self):
        """显示中间roi图像"""
        print('roi:', self.roi_rect)
        self.roi_img = self.big_img[self.roi_rect[2]:self.roi_rect[3], self.roi_rect[0]:self.roi_rect[1],:]
        # 保存中间roi图像
        save_name = '{0}_roi.bmp'.format(big_img_file[:-4])
        cv.imwrite(save_name, self.roi_img)

        tmp_img = self.big_img.copy()
        top_left = (self.roi_rect[1], self.roi_rect[3])
        bottom_right = (self.roi_rect[0], self.roi_rect[2])
        cv.rectangle(tmp_img, top_left, bottom_right, 255, 2)
        plt.subplot(131), plt.title('Big image')
        plt.imshow(tmp_img)
        plt.subplot(132), plt.title('roi image')
        plt.imshow(self.roi_img, cmap='gray')
        # plt.xticks([]), plt.yticks([])
        plt.subplot(133), plt.title('little image')
        plt.imshow(self.little_img, cmap='gray')

        plt.show()

    def show_defect_image(self):
        '''显示缺陷检测结果图像'''
        tmp_img = self.big_img.copy()
        top_left = (50, 50)
        bottom_right = (100, 100)
        cv.rectangle(tmp_img, top_left, bottom_right, (255, 0, 0), 4)

        plt.imshow(tmp_img)
        plt.title('缺陷检测结果')
        plt.show()

    def defect_analyze_patchs(self):
        '''对切片逐个进行二值化和连通分析，返回连通分析结果'''
        is_defected = False
        img_output = self.big_img.copy()
        n_patches = self.roi_patchs.shape[0]
        roi_img_output = self.roi_img.copy()
        for i in range(0, n_patches):
            pos = self.roi_patchs[i, :]
            # patch 在roi中的位置,按照左右上下顺序排列
            path_in_roi = pos[2], pos[0]
            print('patch 在ROI中位置：', path_in_roi)
            patch = self.roi_img[pos[0]:pos[1], pos[2]:pos[3]]
            # 每个小方块的检测结果
            result, width_square = self.square_detect_in_patch(patch)
            for j in range(0, result.shape[0]):
                if result[j, 0] == 1:
                    is_defected = True
                    topleft = int(j*width_square + result[j, 1]), int(result[j, 2])
                    bottomright = int(topleft[0] + result[j, 3]), int(result[j, 4])
                    print('patch 中位置：', topleft, bottomright)

                    # 在切片上画图
                    # cv.rectangle(patch, topleft, bottomright, (255, 0, 0), 10)

                    # 将位置转换至ROI图像上, x,y,width,height
                    pos_in_patch = topleft[0], topleft[1], result[j, 3], result[j, 4]
                    pos_in_roi = self.cvt_pos_to_roi(pos_in_patch, path_in_roi)

                    # 在ROI上画图
                    topleft, bottomright = self.cvt_top_bottom(pos_in_roi[0],
                                                               pos_in_roi[1],
                                                               pos_in_roi[2],
                                                               pos_in_roi[3])
                    print('ROI中位置：', topleft, bottomright)
                    cv.rectangle(roi_img_output, topleft, bottomright, (255, 0, 0), 5)

                    # 将位置转换至整个图像上
                    pos_in_img = self.cvt_pos_to_image(pos_in_patch, path_in_roi)
                    topleft, bottomright = self.cvt_top_bottom(pos_in_img[0],
                                                               pos_in_img[1],
                                                               pos_in_img[2],
                                                               pos_in_img[3])

                    cv.rectangle(img_output, topleft, bottomright, (255, 0, 0), 5)

        plt.subplot(121), plt.imshow(roi_img_output)
        plt.title('ROI区域')
        plt.subplot(122), plt.imshow(img_output)
        title = '有缺陷' if is_defected else '正常'
        plt.title(title)
        plt.show()

    def cvt_top_bottom(self, x, y, width, height):
        '''将x,y,w,h转换为矩阵的topleft和bottomright'''
        topleft = x, y
        bottomright = x + width, y + height

        return topleft, bottomright

    def cvt_pos_to_roi(self, pos_in_patch, patch_in_roi):
        '''将patch坐标转换至roi上'''
        pos_in_roi= []

        x = pos_in_patch[0] + patch_in_roi[0]
        y = pos_in_patch[1] + patch_in_roi[1]

        width = pos_in_patch[2]
        height = pos_in_patch[3]

        pos_in_roi = int(x), int(y), int(width), int(height)

        return pos_in_roi

    def cvt_pos_to_image(self, pos_in_patch, patch_in_roi):
        '''将patch坐标转换至image上'''
        pos_in_image = []

        x = pos_in_patch[0] + patch_in_roi[0] + self.roi_rect[0]
        y = pos_in_patch[1] + patch_in_roi[1] + self.roi_rect[2]

        width = pos_in_patch[2]
        height = pos_in_patch[3]

        pos_in_image = int(x), int(y), int(width), int(height)

        return pos_in_image

    def white_ratio_analyze(self, src):
        """白色比例分析"""
        w, h = src.shape
        # print("ratio分析：", w, h, src.sum(), src.max())
        ratio = src.sum() / (w * h * 255)
        return ratio

    def connect_component(self, src_in):
        """连通区域分析"""
        src = src_in.copy()
        n, labels, stats, centroids = cv.connectedComponentsWithStats(src, connectivity=8)
        # print('----' * 20, "\n", stats)

        # 计算面积 同 长宽比值，或者面积最大值
        # 此处直接使用面积最大值
        # stats：x y width, height, area，行数为n,首行为背景
        array_ratio = stats[1:, -1]

        # print('剔除背景之后：', array_ratio)
        idx = array_ratio.argmax() + 1
        max_ratio = array_ratio.max()
        # print('像素面积：', stats[:, 4], sum(stats[:, 4]))
        # print('面积最大值：', max_ratio, idx)

        # 根据像素点计算实际面积
        mask1 = labels == idx
        # print('实际像素个数:', sum(sum(mask1)))

        # 不同的连通域赋予不同的颜色
        output = np.zeros((labels.shape[0], labels.shape[1], 3), np.uint8)
        for i in range(1, n):
            mask = labels == i
            output[:, :, 0][mask] = np.random.randint(0, 255)
            output[:, :, 1][mask] = np.random.randint(0, 255)
            output[:, :, 2][mask] = np.random.randint(0, 255)

        # 矩形区域
        x, y, width, height, areas = stats[idx, :]
        topleft = x, y
        bottomright = x + width, y + height
        cv.rectangle(output, topleft, bottomright, (255, 255, 0), 2, 1, 0)
        # 中心点
        center = int(centroids[idx, 0]), int(centroids[idx, 1])
        cv.circle(output, center, 5, (255, 0, 0), 1)
        # cv.putText(output,"{0}".format(areas),center,cv.FONT_HERSHEY_SIMPLEX,0.5,(255,0,0),1)

        # for idx_i in range(0,n):
        #     # 矩形区域
        #     x, y, width, height, areas = stats[idx_i, :]
        #     print(idx_i, '循环:', areas)
        #     topleft = x, y
        #     bottomright = x + width, y + height
        #     #cv.rectangle(output, topleft, bottomright, (255, 255, 0), 1, 1, 0)
        #     # 中心点
        #     center = int(centroids[idx_i, 0]), int(centroids[idx_i, 1])
        #     cv.circle(output, center, 5, (255, 0, 0), 1)
        #     # cv.putText(output,"{0}".format(areas),center,cv.FONT_HERSHEY_SIMPLEX,0.5,(255,0,0),1)

        return output, stats[idx, :]

    def square_detect_in_patch(self, patch):
        # ROI图像中共11块
        N = 11
        result = np.zeros((N, 6))
        h, w = patch.shape[:2]
        one_w = int(w / N)
        for i in range(0, N):
            # print(f'----------------------分块序号：{i}---------------------')
            # 抽取块，然后计算相关系数
            tmp_img = patch[:, i * one_w + 3:(i + 1) * one_w + 3]
            bw_img = self.conduct_binary(tmp_img)

            # 形态学滤波
            bw_img_in = bw_img.copy()
            kernel = cv.getStructuringElement(cv.MORPH_RECT, (5, 5))
            morph_open_img = cv.morphologyEx(bw_img_in, cv.MORPH_OPEN, kernel, 3)
            # morph_close_img = cv.morphologyEx(bw_img, cv.MORPH_CLOSE, kernel, 3)

            # 白色比例分析
            ratio_value = self.white_ratio_analyze(morph_open_img)

            # 连通区域分析，这里可设置阈值判断是否为缺陷区域
            morph_open_img_label, stat = self.connect_component(morph_open_img)
            # print('分块序号：{0}，最大区域：{1}'.format(i, stat))
            if stat[4] > self.thresh_value:
                result[i, 0] = 1
                result[i, 1:] = stat

        return result, one_w


if __name__ == '__main__':
    # 匹配文件名称
    template_img_file = sys.argv[1].replace('\\','/')
    big_img_file = sys.argv[2].replace('\\','/')
    # print(template_img_file)
    # print(big_img_file)
    # print(int(sys.argv[3]))
    detector = DetectDefect()
    detector.set_big_image(big_img_file)
    detector.set_little_image(template_img_file)
    detector.set_roi_rect()
    # detector.set_thresh_value(thresh_value=6500) # 2000, 6500
    detector.set_thresh_value(thresh_value=int(sys.argv[3]))  # 2000, 6500
    detector.show_roi_img()
    detector.split_patches()
    detector.defect_analyze_patchs()
